import os
import sys
import json
import hashlib
import pymongo
import argparse
import subprocess
import multiprocessing
from avclass_common3 import AvLabels
from avclass_common3_mod import AvLabels as AvLabels_mod

def get_md5_token_counts(file_path, avlabels):
    """Collect per-sample token counts from a file of JSON report lines.

    Each line of ``file_path`` is parsed as JSON. A line is skipped when it
    is not valid JSON, when it has no ``"md5"`` value, or when
    ``avlabels.get_sample_info`` returns ``None`` for it.

    Args:
        file_path: path of the text file to read, one JSON document per line.
        avlabels: object providing ``get_sample_info(report, True)`` and
            ``get_family_ranking(sample_info)``; the ranking must be
            convertible with ``dict(...)``.

    Returns:
        A list of ``[md5, [(token, count), ...]]`` entries in file order.
    """
    results = []
    with open(file_path, "r") as report_file:
        for raw_line in report_file:
            try:
                report = json.loads(raw_line)
            except json.decoder.JSONDecodeError:
                # Malformed lines are ignored rather than aborting the run.
                continue

            sample_md5 = report.get("md5")
            if sample_md5 is None:
                continue

            info = avlabels.get_sample_info(report, True)
            if info is None:
                continue

            # Round-trip through a dict so repeated tokens collapse to a
            # single entry while keeping the ranking's order.
            ranking = dict(avlabels.get_family_ranking(info))
            results.append([sample_md5, list(ranking.items())])
    return results


def update_db(all_md5_token_counts, thresholds, db):
    """Turn ranked token counts into label documents and store them in ``db``.

    For every ``[md5_token_counts, label_type]`` pair, the collection
    ``db[label_type]`` is dropped if it already holds documents and is then
    filled with one document per sample, inserted in batches of 1000.

    Labelling rules:
      * ``label_type == "av_majority_vote"``: the document gets a single
        ``label_type`` field holding the top token when its count is >= 4
        and the sum of all counts is <= 7, otherwise the sample's md5.
      * any other ``label_type``: for each threshold the document gets a
        ``"<label_type>_<threshold>"`` field holding the top token when the
        gap between the first and second counts (or the only count) reaches
        the threshold, otherwise the sample's md5.

    The entry at index 0 of ``token_counts`` is treated as the top-ranked
    token (assumes the ranking is ordered by descending count -- confirm
    against the producer of the ranking).

    Args:
        all_md5_token_counts: iterable of ``[md5_token_counts, label_type]``
            pairs, where ``md5_token_counts`` is a list of
            ``[md5, [(token, count), ...]]`` entries.
        thresholds: iterable of numeric thresholds.
        db: database handle; ``db[name]`` must return a collection offering
            ``find_one()``, ``drop()`` and ``insert_many()``.

    Returns:
        The last label document that was built, or ``None`` when there was
        nothing to label.
    """
    batch_size = 1000
    # Defined up front so that empty input returns None instead of raising
    # UnboundLocalError at the final return.
    labels = None

    for md5_token_counts, label_type in all_md5_token_counts:
        collection = db[label_type]
        # find_one() is used as the emptiness check because
        # Collection.count() no longer exists in PyMongo 4.
        if collection.find_one() is not None:
            print("[+] Dropping collection {}".format(label_type))
            collection.drop()

        num_labeled = 0
        to_insert = []
        for md5, token_counts in md5_token_counts:
            labels = {"md5": md5}
            # NOTE(review): the caller in this file passes
            # "avclass_av_majority_vote", which does not equal this literal,
            # so that collection currently takes the threshold branch.
            # Confirm which behaviour is intended before changing it.
            if label_type == "av_majority_vote":
                labels[label_type] = md5
                if token_counts:
                    token_count_sum = sum(entry[1] for entry in token_counts)
                    max_count = token_counts[0][1]
                    if max_count >= 4 and token_count_sum <= 7:
                        labels[label_type] = token_counts[0][0]
            else:
                if len(token_counts) >= 2:
                    diff = token_counts[0][1] - token_counts[1][1]
                elif len(token_counts) == 1:
                    diff = token_counts[0][1]
                else:
                    diff = None
                for threshold in thresholds:
                    field = "{}_{}".format(label_type, str(threshold))
                    if diff is not None and diff >= threshold:
                        labels[field] = token_counts[0][0]
                    else:
                        # No confident label: fall back to the md5 so the
                        # sample forms its own singleton class.
                        labels[field] = md5

            to_insert.append(labels)
            num_labeled += 1
            if num_labeled % batch_size == 0:
                collection.insert_many(to_insert, ordered=False)
                to_insert = []
                print(label_type, num_labeled)
                sys.stdout.flush()

        # Flush the final, partially filled batch.
        if to_insert:
            collection.insert_many(to_insert, ordered=False)
            print(label_type, num_labeled)
            sys.stdout.flush()

    return labels


def _run_alias_detect(aliaser, vt_path, out_path, nalias, talias, cwd):
    """Run the alias-detection script on ``vt_path``, saving stdout to ``out_path``.

    The script is started with an argument list (no shell), with ``cwd`` as
    its working directory and its stderr discarded. ``out_path`` is always
    (re)created, so a failed run leaves an empty file behind. Failures are
    reported on stdout and otherwise ignored (best effort).
    """
    argv = [aliaser, "-vt", vt_path, "-nalias", nalias, "-talias", talias]
    with open(out_path, "w") as out_file:
        try:
            returncode = subprocess.call(argv, stdout=out_file,
                                         stderr=subprocess.DEVNULL, cwd=cwd)
        except OSError as err:
            print("[!] Could not run {}: {}".format(aliaser, err))
            return
    if returncode != 0:
        print("[!] {} exited with status {}".format(aliaser, returncode))


def handle_avclass_mode():
    """Label every file in ``args.label_dir`` with AVClass and store the results.

    For each file (in sorted name order) this regenerates the ``prep`` and
    ``prep_strict`` alias files with the alias-detection script, computes
    token counts with five labeler configurations, and passes them to
    ``update_db`` with thresholds 0 through 5.

    Reads the module-level ``args`` (``avclass_dir``, ``label_dir``,
    ``db_name``) and connects to MongoDB on 127.0.0.1:27017.
    """
    print("[-] Labeling using AVClass")
    sys.stdout.flush()

    # Resolve both directories once. The alias detector runs with the
    # AVClass directory as its working directory, so every path handed to it
    # has to be absolute to remain valid there.
    avclass_dir = os.path.abspath(args.avclass_dir)
    label_dir = os.path.abspath(args.label_dir)

    aliaser = os.path.join(avclass_dir, "avclass_alias_detect.py")
    gen_path = os.path.join(avclass_dir, "data/default.generics")
    alias_default_path = os.path.join(avclass_dir, "data/default.aliases")
    alias_prep_path = os.path.join(avclass_dir, "data/prep.aliases")
    alias_prep_strict_path = os.path.join(avclass_dir,
                                          "data/prep_strict.aliases")
    genav_path = os.path.join(avclass_dir, "data/default.genav")
    av_path = os.path.join(avclass_dir, "data/default.av")

    # (collection name, labeler class, constructor keyword arguments)
    labeler_configs = [
        ("avclass_default_prep", AvLabels,
         {"gen_file": gen_path, "alias_file": alias_default_path}),
        ("avclass_alias_prep", AvLabels,
         {"gen_file": gen_path, "alias_file": alias_prep_path}),
        ("avclass_alias_prep_strict", AvLabels,
         {"gen_file": gen_path, "alias_file": alias_prep_strict_path}),
        ("avclass_genav_remove", AvLabels_mod,
         {"gen_file": gen_path, "alias_file": alias_prep_path,
          "genav_file": genav_path}),
        ("avclass_av_majority_vote", AvLabels,
         {"gen_file": gen_path, "alias_file": alias_default_path,
          "av_file": av_path}),
    ]

    # NOTE(review): update_db drops a non-empty target collection before
    # inserting, so with several label files only the labels of the last
    # file appear to survive -- confirm this is intended.
    for file_name in sorted(os.listdir(label_dir)):
        file_path = os.path.join(label_dir, file_name)

        # Regenerate the alias file with the default detection settings
        print("[-] Preparing default aliases: {}".format(file_path))
        sys.stdout.flush()
        _run_alias_detect(aliaser, file_path, alias_prep_path,
                          "20", "0.94", avclass_dir)

        # Regenerate the alias file with the stricter detection settings
        print("[-] Preparing strict aliases: {}".format(file_path))
        sys.stdout.flush()
        _run_alias_detect(aliaser, file_path, alias_prep_strict_path,
                          "100", "0.98", avclass_dir)

        # Get AVClass labels. Labelers are built here, after the alias files
        # for this input have been regenerated.
        print("[-] Getting AVClass labels: {}".format(file_path))
        sys.stdout.flush()
        all_md5_token_counts = []
        for label_type, labeler_cls, kwargs in labeler_configs:
            avlabels = labeler_cls(**kwargs)
            all_md5_token_counts.append(
                [get_md5_token_counts(file_path, avlabels), label_type])

        # Update db
        print("[-] Updating db: {}".format(file_path))
        sys.stdout.flush()
        client = pymongo.MongoClient("127.0.0.1", 27017)
        try:
            update_db(all_md5_token_counts, [0, 1, 2, 3, 4, 5],
                      client[args.db_name])
        finally:
            client.close()
    return


def get_dir_label(file_path):
    """Build the directory-based label record for one sample file.

    The md5 comes from the file name when it has exactly two
    underscore-separated parts (the second part, lower-cased); in that case
    the file is not opened. Otherwise the md5 is computed from the file's
    contents.

    Args:
        file_path: path of the sample file.

    Returns:
        ``{"md5": <md5 hex string>, "label": <name of the parent directory>}``

    Raises:
        OSError: if the file has to be hashed and cannot be read.
    """
    name_parts = os.path.basename(file_path).split("_")
    if len(name_parts) == 2:
        md5 = name_parts[1].lower()
    else:
        # Hash in fixed-size chunks so large samples are not loaded into
        # memory at once, and close the handle deterministically.
        digest = hashlib.md5()
        with open(file_path, "rb") as sample:
            for chunk in iter(lambda: sample.read(1 << 20), b""):
                digest.update(chunk)
        md5 = digest.hexdigest()

    return {
        "md5": md5,
        "label": os.path.basename(os.path.dirname(file_path)),
    }


def handle_dir_mode():
    """Label every file under ``args.label_dir`` by its parent directory name.

    Walks the directory tree, computes a ``get_dir_label`` record for each
    regular file using a pool of 12 worker processes, and inserts the records
    into the ``dir_labels`` collection in batches of 1000.

    Reads the module-level ``args`` (``label_dir``, ``db_name``) and connects
    to MongoDB on 127.0.0.1:27017.
    """
    label_file_paths = set()
    for path, _subdirectories, file_names in os.walk(args.label_dir):
        for file_name in file_names:
            file_path = os.path.join(path, file_name)
            if os.path.isfile(file_path):
                label_file_paths.add(file_path)
    file_paths = sorted(label_file_paths)

    # Split file paths into batches
    batch_size = 1000
    batches = [file_paths[start:start + batch_size]
               for start in range(0, len(file_paths), batch_size)]

    # Get dir labels and insert into db
    total_processed = 0
    pool = multiprocessing.Pool(12)
    client = None
    try:
        # The client is opened after the pool exists, as in the per-batch
        # original, and is reused for all batches.
        client = pymongo.MongoClient("127.0.0.1", 27017)
        collection = client[args.db_name]["dir_labels"]
        for batch in batches:
            all_dir_labels = pool.map(get_dir_label, batch)
            collection.insert_many(all_dir_labels, ordered=False)
            total_processed += len(all_dir_labels)
            print("[-] Processed {} malware samples".format(total_processed))
            sys.stdout.flush()
    finally:
        # Release the connection and the workers even if a batch failed.
        if client is not None:
            client.close()
        pool.close()
        pool.join()

    return


if __name__ == "__main__":

    # Parse command line arguments. The result is stored in the module-level
    # name ``args``, which the mode handlers read directly.
    parser = argparse.ArgumentParser()
    parser.add_argument("--label-dir",
                        default="/media/data1/labels/")
    parser.add_argument("--db-name", default="agtr_db")
    parser.add_argument("--drop-database", default=False, action="store_true")
    parser.add_argument("--collection-name", default="avclass_labels")

    subparsers = parser.add_subparsers(title="modes", dest="mode")

    avclass_parser = subparsers.add_parser("avclass")
    avclass_parser.add_argument("--avclass-dir", default="avclass/")

    subparsers.add_parser("dir")

    args = parser.parse_args()

    # Re-initialize db table
    # NOTE(review): the mode handlers in this file write to per-label-type
    # collections and to "dir_labels", not to args.collection_name -- confirm
    # that this collection is still needed.
    client = pymongo.MongoClient("127.0.0.1", 27017)
    try:
        if args.drop_database and args.db_name in client.list_database_names():
            print("[+] Dropping database: {}".format(args.db_name))
            client.drop_database(args.db_name)
        db = client[args.db_name]
        collection = db[args.collection_name]
        # find_one() is used as the emptiness check because
        # Collection.count() no longer exists in PyMongo 4.
        if collection.find_one() is not None:
            print("[+] Dropping collection {}".format(args.collection_name))
            collection.drop()
        collection.create_index([("md5", pymongo.HASHED)])
        print("[+] Created database collection {}".format(args.collection_name))
    finally:
        client.close()

    # Handle current mode
    if args.mode == "avclass":
        handle_avclass_mode()
    elif args.mode == "dir":
        handle_dir_mode()
    else:
        # Make the no-op visible instead of exiting silently.
        print("[!] No mode selected; choose 'avclass' or 'dir'")
